package com.mlamp.回溯;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Backtracking searches for every combination of {@code k} distinct digits from 1 to 9
 * whose sum is {@code n}.
 *
 * <p>Not thread-safe: every search writes its results into the shared {@link #res} field.
 */
public class 组合总数III {

    /** Largest digit that may appear in a combination; the smallest is 1. */
    private static final int MAX_DIGIT = 9;

    /**
     * Combinations found by the most recent search. Each public entry point replaces this set
     * before searching, so results from an earlier call never leak into a later one. A
     * {@link LinkedHashSet} keeps the combinations in the order the search discovers them.
     */
    public Set<List<Integer>> res = new LinkedHashSet<>();

    public static void main(String[] args) {

        组合总数III instance = new 组合总数III();
        List<List<Integer>> lists = instance.combinationSum3(3, 9);
        for (List<Integer> item : lists) {
            System.out.println(Arrays.toString(item.toArray()));
        }
    }

    /**
     * Returns every combination of {@code k} distinct digits in [1, 9] that sums to {@code n},
     * using a loop-over-candidates backtracking search.
     *
     * @param k number of digits in each combination
     * @param n required sum of the digits
     * @return a new mutable list of combinations, each in ascending order, listed in
     *     lexicographic order; empty when no combination exists
     */
    public List<List<Integer>> combinationSum3(int k, int n) {
        res = new LinkedHashSet<>();
        core(k, n, new LinkedList<>(), 1);
        return new ArrayList<>(res);
    }

    /**
     * Same contract as {@link #combinationSum3(int, int)}, implemented instead as an
     * include/exclude recursion that decides digit by digit whether to take it.
     *
     * @param k number of digits in each combination
     * @param n required sum of the digits
     * @return a new mutable list of combinations, each in ascending order, listed in
     *     lexicographic order; empty when no combination exists
     */
    public List<List<Integer>> combinationSum4(int k, int n) {
        res = new LinkedHashSet<>();
        core1(k, n, new LinkedList<>(), 1);
        return new ArrayList<>(res);
    }

    /**
     * Include/exclude backtracking: first explores combinations that contain digit {@code i},
     * then those that skip it, and records each combination of size {@code k} whose digits
     * add up exactly to the original target.
     *
     * @param k         required combination size
     * @param remaining sum still needed on top of the digits already in {@code trace}
     * @param trace     digits chosen so far (ascending); restored before returning
     * @param i         next digit to decide on
     */
    private void core1(int k, int remaining, LinkedList<Integer> trace, int i) {
        if (trace.size() == k) {
            if (remaining == 0) {
                res.add(new ArrayList<>(trace));
            }
            return;
        }
        // Every digit from i onward is at least i, so none fits once i exceeds what is left.
        if (i > MAX_DIGIT || i > remaining) {
            return;
        }
        trace.add(i);
        core1(k, remaining - i, trace, i + 1); // take digit i
        trace.removeLast();
        core1(k, remaining, trace, i + 1); // skip digit i
    }

    /**
     * Loop-based backtracking: adds to {@link #res} every extension of {@code trace} by
     * ascending digits drawn from [{@code start}, 9] (skipping digits already in {@code trace})
     * that ends with exactly {@code k} elements summing to {@code n}.
     *
     * @param k     required combination size
     * @param n     required sum of the whole combination
     * @param trace digits chosen so far; restored to its original contents before returning
     * @param start smallest digit that may be appended next
     */
    public void core(int k, int n, LinkedList<Integer> trace, int start) {
        int total = sum(trace);
        if (trace.size() >= k) {
            if (trace.size() == k && total == n) {
                res.add(new ArrayList<>(trace));
            }
            // Appending more digits can never bring the size back to k.
            return;
        }

        for (int i = start; i <= MAX_DIGIT; i++) {
            // Candidates are positive and increasing: once one overshoots n, all later ones do.
            if (total + i > n) {
                break;
            }
            if (trace.contains(i)) {
                continue;
            }
            trace.add(i);
            // Continue from i + 1 so each combination is generated once, in ascending order.
            core(k, n, trace, i + 1);
            trace.removeLast();
        }
    }

    /**
     * Returns the sum of the given values.
     *
     * @param values numbers to add; must not contain {@code null}
     * @return the arithmetic sum of {@code values} (0 when empty)
     */
    public int sum(LinkedList<Integer> values) {
        int total = 0;
        for (int value : values) {
            total += value;
        }
        return total;
    }
}
